{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Multidimensional Item Response Theory with Gradient Descent Optimization (GDIRT)\n",
    "\n",
    "This notebook shows how to train and use the MIRT model.\n",
    "First, we show how to load the data (here we use a0910 as the dataset).\n",
    "Then, we show how to train a MIRT model and persist its parameters to a file.\n",
    "Finally, we show how to load the parameters from the file and evaluate the model on the test dataset.\n",
    "\n",
    "The script version can be found in [MIRT.py](MIRT.py)."
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Data Preparation\n",
    "\n",
    "Before processing the data, we first need to acquire the dataset, as shown in [prepare_dataset.ipynb](prepare_dataset.ipynb)."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "outputs": [
    {
     "data": {
      "text/plain": "   user_id  item_id  score\n0     1615    12977      1\n1      782    13124      0\n2     1084    16475      0\n3      593     8690      0\n4      127    14225      1",
      "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>user_id</th>\n      <th>item_id</th>\n      <th>score</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>0</th>\n      <td>1615</td>\n      <td>12977</td>\n      <td>1</td>\n    </tr>\n    <tr>\n      <th>1</th>\n      <td>782</td>\n      <td>13124</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>2</th>\n      <td>1084</td>\n      <td>16475</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>3</th>\n      <td>593</td>\n      <td>8690</td>\n      <td>0</td>\n    </tr>\n    <tr>\n      <th>4</th>\n      <td>127</td>\n      <td>14225</td>\n      <td>1</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the data from files\n",
    "import pandas as pd\n",
    "\n",
    "train_data = pd.read_csv(\"../../data/a0910/train.csv\")\n",
    "valid_data = pd.read_csv(\"../../data/a0910/valid.csv\")\n",
    "test_data = pd.read_csv(\"../../data/a0910/test.csv\")\n",
    "\n",
    "train_data.head(5)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [
    {
     "data": {
      "text/plain": "(186049, 25606, 55760)"
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_data), len(valid_data), len(test_data)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "data": {
      "text/plain": "(<torch.utils.data.dataloader.DataLoader at 0x213d0517e80>,\n <torch.utils.data.dataloader.DataLoader at 0x213d051a5e0>,\n <torch.utils.data.dataloader.DataLoader at 0x213d051a9d0>)"
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Transform data to torch DataLoader (i.e., batchify)\n",
    "# batch_size is set to 256\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "batch_size = 256\n",
    "def transform(x, y, z, batch_size, **params):\n",
    "    dataset = TensorDataset(\n",
    "        torch.tensor(x, dtype=torch.int64),\n",
    "        torch.tensor(y, dtype=torch.int64),\n",
    "        torch.tensor(z, dtype=torch.float)\n",
    "    )\n",
    "    return DataLoader(dataset, batch_size=batch_size, **params)\n",
    "\n",
    "train, valid, test = [\n",
    "    transform(data[\"user_id\"], data[\"item_id\"], data[\"score\"], batch_size)\n",
    "    for data in [train_data, valid_data, test_data]\n",
    "]\n",
    "train, valid, test"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Training and Persistence"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "import logging\n",
    "logging.getLogger().setLevel(logging.INFO)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0: 100%|██████████| 727/727 [00:17<00:00, 41.54it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 298.14it/s]\n",
      "Epoch 1: 100%|██████████| 727/727 [00:20<00:00, 35.42it/s]\n",
      "evaluating: 100%|██████████| 101/101 [00:00<00:00, 279.76it/s]\n",
      "INFO:root:save parameters to mirt.params\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Epoch 0] LogisticLoss: 0.669528\n",
      "[Epoch 0] auc: 0.503206, accuracy: 0.659845\n",
      "[Epoch 1] LogisticLoss: 0.643967\n",
      "[Epoch 1] auc: 0.503432, accuracy: 0.659845\n"
     ]
    }
   ],
   "source": [
    "from EduCDM import MIRT\n",
    "\n",
    "cdm = MIRT(4164, 17747, 123)\n",
    "\n",
    "cdm.train(train, valid, epoch=2)\n",
    "cdm.save(\"mirt.params\")"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading and Testing"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:load parameters from mirt.params\n",
      "evaluating: 100%|██████████| 218/218 [00:00<00:00, 263.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "auc: 0.496306, accuracy: 0.654322\n"
     ]
    }
   ],
   "source": [
    "cdm.load(\"mirt.params\")\n",
    "auc, accuracy = cdm.eval(test)\n",
    "print(\"auc: %.6f, accuracy: %.6f\" % (auc, accuracy))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}